/*
 * Written by Dawid Kurzyniec and released to the public domain, as explained
 * at http://creativecommons.org/licenses/publicdomain
 */

package edu.emory.mathcs.util.io;

import java.io.*;

/**
 * Input stream that decodes from base64 on-the-fly, that is, it converts
 * base64-encoded character stream into a decoded byte stream. It also offers
 * static methods for offline decoding, that is, converting base64-encoded strings
 * into byte arrays.
 *
 * @author Dawid Kurzyniec
 * @version 1.0
 */

public class Base64Decoder extends InputStream {

    /** Source of base64 characters; closed by {@link #close()}. */
    final Reader in;
    /** Scratch buffer holding the four characters of one base64 quantum. */
    private final char[] auxCharBuf = new char[4];
    /** Decoded bytes of the most recent quantum that were cached for later reads. */
    private final byte[] triplet = new byte[3];
    /**
     * Number of cached bytes. Irrelevant if tripletIdx == 0.
     */
    int tripletLen = 0;
    /**
     * Index to the array of cached bytes. If 0, there are no unprocessed
     * cached bytes thus the next triplet of data needs to be read,
     * regardless of the tripletLen value. If it is positive, equal to
     * tripletLen, and tripletLen is less than 3, a padded (final) quantum
     * has been fully consumed and the stream is at its end.
     */
    int tripletIdx = 0;

    /**
     * Creates a new base64 decoder that reads characters from a specified
     * reader.
     *
     * @param in the reader to read from
     */
    public Base64Decoder(Reader in) {
        this.in = in;
    }

    /**
     * Reads the next three bytes, if possible, to the 'dest' buffer at the
     * offset 'off'. Returns a number of bytes that were actually read,
     * between 0 and 3. A result of 0 means that the underlying reader was
     * already at its end; a result of 1 or 2 means that a padded quantum
     * was decoded.
     *
     * @param dest the destination buffer
     * @param off offset within the destination buffer
     * @return number of bytes read
     * @throws StreamCorruptedException if the reader ends in the middle of
     *         a quantum, or the quantum contains illegal characters or
     *         malformed padding
     * @throws IOException if the underlying reader fails
     */
    private int readTriplet(byte[] dest, int off) throws IOException {
        char[] ch = this.auxCharBuf;

        // A Reader is allowed to return fewer characters than requested
        // without being at EOF, so keep reading until the quantum is
        // complete or the reader really ends.
        int filled = 0;
        while (filled < 4) {
            int count = in.read(ch, filled, 4 - filled);
            if (count < 0) break;
            filled += count;
        }
        if (filled == 0) return 0;
        if (filled < 4) throw new StreamCorruptedException("Base64 stream abruptly terminated");

        int ch0 = base64toInt(ch[0]);
        int ch1 = base64toInt(ch[1]);
        int ch2 = ch[2];
        int ch3 = ch[3];

        // The casts to byte discard the bits above the low eight.
        dest[off++] = (byte) ((ch0 << 2) | (ch1 >> 4));
        if (ch2 == '=') {
            if (ch3 != '=') throw new StreamCorruptedException("Expected '='");
            return 1;
        }

        ch2 = base64toInt(ch2);
        dest[off++] = (byte) ((ch1 << 4) | (ch2 >> 2));
        if (ch3 == '=') {
            return 2;
        }

        ch3 = base64toInt(ch3);
        dest[off] = (byte) ((ch2 << 6) | ch3);
        return 3;
    }

    /**
     * Copies cached bytes into 'buf' starting at 'from', stopping when
     * either the cache is exhausted or 'limit' is reached.
     *
     * @param buf the destination buffer
     * @param from index of the first byte to write
     * @param limit index one past the last byte that may be written
     * @return index one past the last byte actually written
     */
    private int drainTriplet(byte[] buf, int from, int limit) {
        int pos = from;
        while (tripletIdx < tripletLen && pos < limit) {
            buf[pos++] = triplet[tripletIdx++];
        }
        return pos;
    }

    /**
     * Reads up to 'len' decoded bytes into 'buf' starting at 'off'.
     *
     * @param buf the destination buffer
     * @param off offset at which to start storing bytes
     * @param len maximum number of bytes to read
     * @return the number of bytes stored, 0 if 'len' is 0, or -1 if the end
     *         of the stream was reached before any byte could be stored
     * @throws ArrayIndexOutOfBoundsException if 'off' or 'len' is negative,
     *         or 'len' exceeds the space available after 'off'
     * @throws StreamCorruptedException if the input is not valid base64
     * @throws IOException if the underlying reader fails
     */
    public int read(byte[] buf, int off, int len) throws IOException {
        if (off < 0 || len < 0 || len > buf.length - off) {
            // Rejected up front: a negative length would otherwise decode a
            // quantum into the cache and then silently discard it.
            throw new ArrayIndexOutOfBoundsException(
                "off=" + off + ", len=" + len + ", buf.length=" + buf.length);
        }
        if (len == 0) return 0;

        int newoff = off;
        int maxoff = off+len;
        if (tripletIdx > 0) {
            // finish off the pending triplet
            newoff = drainTriplet(buf, newoff, maxoff);
            if (tripletIdx < tripletLen) return newoff-off;
            if (tripletLen < 3) {
                // EOF hit; have we copied any bytes just above?...
                return newoff-off > 0 ? newoff-off : -1;
            }
            tripletIdx = 0;
        }
        // INV: tripletIdx == 0

        // fast path: decode whole quanta directly into the caller's buffer
        while (newoff+3 <= maxoff) {
            int read = readTriplet(buf, newoff);
            newoff += read;
            if (read < 3) {
                // EOF. If it was signalled by padding, remember it the same
                // way the single-byte and tail paths do, so that later reads
                // report EOF instead of decoding past the padding.
                if (read > 0) {
                    tripletLen = read;
                    tripletIdx = read;
                }
                return newoff-off > 0 ? newoff-off : -1;
            }
        }

        if (newoff == maxoff) return newoff-off;

        // tail: fewer than three bytes requested; decode into the cache
        tripletLen = readTriplet(triplet, 0);
        // POST: tripletIdx > 0 unless EOF, in which case future reads will
        // hit EOF again on readTriplet
        newoff = drainTriplet(buf, newoff, maxoff);
        return newoff-off > 0 ? newoff-off : -1;
    }

    /**
     * Reads the next decoded byte.
     *
     * @return the byte as an int in the range 0 to 255, or -1 at the end of
     *         the stream
     * @throws StreamCorruptedException if the input is not valid base64
     * @throws IOException if the underlying reader fails
     */
    public int read() throws IOException {
        if (tripletIdx > 0) {
            // within a quantum already
            if (tripletIdx < tripletLen) {
                // more buffered bytes to return; mask so that values above
                // 0x7F are not sign-extended (0xFF must not look like EOF)
                return triplet[tripletIdx++] & 0xFF;
            }
            if (tripletLen < 3) {
                // end of stream
                return -1;
            }
            tripletIdx = 0;
            // continue and read next triplet
        }
        // INV: tripletIdx == 0
        tripletLen = readTriplet(triplet, 0);
        if (tripletLen == 0) {
            // EOF
            return -1;
        }
        // len must be 1, 2, or 3; idx is 0
        return triplet[tripletIdx++] & 0xFF;
    }

    /**
     * Closes the underlying reader.
     *
     * @throws IOException if closing the reader fails
     */
    public void close() throws IOException {
        in.close();
    }

    /**
     * Decodes a base64-encoded string into a byte array.
     *
     * @param base64encoded the string to decode
     * @return decoded array
     * @throws IllegalArgumentException if the string is not valid base64
     */
    public static byte[] decode(String base64encoded)
        throws IllegalArgumentException
    {
        int expectedLen = getDecodedArrayLength(base64encoded);
        byte[] result = new byte[expectedLen];
        decode(base64encoded, result, 0, expectedLen);
        return result;
    }

    /**
     * Decodes a base64-encoded string and write the result to the specified
     * byte array.
     *
     * @param base64encoded the string to decode
     * @param buf the array to write to
     * @throws IllegalArgumentException if the string is not valid base64
     */
    public static void decode(String base64encoded, byte[] buf)
        throws IllegalArgumentException
    {
        decode(base64encoded, buf, 0);
    }

    /**
     * Decodes a base64-encoded string and write the result to the specified
     * byte array at the given offset.
     *
     * @param base64encoded the string to decode
     * @param buf the array to write to
     * @param off the offset to start writing at
     * @throws IllegalArgumentException if the string is not valid base64
     */
    public static void decode(String base64encoded, byte[] buf, int off) {
        decode(base64encoded, buf, off, getDecodedArrayLength(base64encoded));
    }

    /**
     * Decodes exactly 'expectedLen' bytes of the given string into 'buf'
     * starting at 'off'.
     *
     * @param base64encoded the string to decode
     * @param buf the array to write to
     * @param off the offset to start writing at
     * @param expectedLen number of decoded bytes the string must yield
     * @throws IllegalArgumentException if the string is not valid base64 or
     *         yields fewer than 'expectedLen' bytes
     */
    private static void decode(String base64encoded, byte[] buf, int off, int expectedLen)
        throws IllegalArgumentException
    {
        StringReader sr = new StringReader(base64encoded);
        Base64Decoder dec = new Base64Decoder(sr);
        int len = expectedLen;
        try {
            while (len > 0) {
                int read = dec.read(buf, off, len);
                if (read < 0) {
                    throw new IllegalArgumentException("Encoded array prematurely ended");
                }
                len -= read;
                off += read;
            }
        }
        catch (StreamCorruptedException e) {
            // keep the decoder's diagnostic reachable as the cause
            IllegalArgumentException iae = new IllegalArgumentException(base64encoded);
            iae.initCause(e);
            throw iae;
        }
        catch (IOException e) {
            // can't happen from the code of Base64Decoder and StringReader
            throw new RuntimeException(e);
        }
    }

    /**
     * Computes the number of bytes that the given base64 string decodes to,
     * based on its length and on up to two trailing '=' characters.
     *
     * @param base64encoded the string to examine
     * @return the decoded length in bytes
     * @throws IllegalArgumentException if the string length is not a
     *         multiple of four
     */
    private static int getDecodedArrayLength(String base64encoded)
        throws IllegalArgumentException
    {
        int slen = base64encoded.length();
        int groups = slen/4;
        if (4*groups != slen) {
            throw new IllegalArgumentException("String length must be a multiple of four.");
        }
        int blen = 3*groups;
        if (slen != 0) {
            if (base64encoded.charAt(slen-1) == '=') blen--;
            if (base64encoded.charAt(slen-2) == '=') blen--;
        }
        return blen;
    }

    /**
     * Translates the specified character from the base 64 alphabet into
     * a 6-bit positive integer.
     *
     * @param c the character to translate
     * @return the value of the character, between 0 and 63
     * @throws StreamCorruptedException if c is not in the base 64 alphabet,
     *         including characters that lie outside the lookup table
     */
    private static int base64toInt(int c) throws IOException {
        // The table only covers characters up to 'z'; anything beyond it is
        // illegal and must not surface as an array index failure.
        int result = (c >= 0 && c < base64ToInt.length) ? base64ToInt[c] : -1;
        if (result < 0)
            throw new StreamCorruptedException("Illegal character " + (char)c);
        return result;
    }

    /**
     * Lookup table that translates characters of the base64 alphabet (as
     * specified in RFC 2045) into their 6-bit positive integer equivalents.
     * Characters which do not belong to the base 64 alphabet are mapped
     * to -1.
     */
    private static final byte base64ToInt[] = {
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63, 52, 53, 54,
        55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, 3, 4,
        5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
        24, 25, -1, -1, -1, -1, -1, -1, 26, 27, 28, 29, 30, 31, 32, 33, 34,
        35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51
    };

}
